{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "72f0b224",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/low_level/router.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "645ef724-bba9-4d9d-89b0-a470dbfb1713",
   "metadata": {},
   "source": [
    "# Building a Router from Scratch\n",
    "\n",
    "In this tutorial, we show you how to build an LLM-powered router module that can route a user query to submodules.\n",
    "\n",
    "Routers are a simple but effective form of automated decision making that lets you perform dynamic retrieval/querying over your data.\n",
    "\n",
    "In LlamaIndex, this is abstracted away with our [Router Modules](https://gpt-index.readthedocs.io/en/latest/core_modules/query_modules/router/root.html).\n",
    "\n",
    "To build a router, we'll walk through the following steps:\n",
    "- Crafting an initial prompt to select a set of choices\n",
    "- Enforcing structured output (for text completion endpoints)\n",
    "- Integrating with a native function calling endpoint\n",
    "\n",
    "We'll then plug the router into a RAG pipeline that dynamically decides between question answering and summarization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac3b2770-a695-48dc-88e8-371cf0a7d7d0",
   "metadata": {},
   "source": [
    "## 1. Set Up a Basic Router Prompt\n",
    "\n",
    "At its core, a router is a module that takes in a set of choices. Given a user query, it \"selects\" a relevant choice.\n",
    "\n",
    "For simplicity, we'll start with the choices as a set of strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d8342c9-0f43-4898-9abd-96c92c177ac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import PromptTemplate\n",
    "\n",
    "choices = [\n",
    "    \"Useful for questions related to apples\",\n",
    "    \"Useful for questions related to oranges\",\n",
    "]\n",
    "\n",
    "\n",
    "def get_choice_str(choices):\n",
    "    choices_str = \"\\n\\n\".join(\n",
    "        [f\"{idx+1}. {c}\" for idx, c in enumerate(choices)]\n",
    "    )\n",
    "    return choices_str\n",
    "\n",
    "\n",
    "choices_str = get_choice_str(choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a2da874-9176-4eec-b3fc-f22cb64bb6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "router_prompt0 = PromptTemplate(\n",
    "    \"Some choices are given below. It is provided in a numbered list (1 to\"\n",
    "    \" {num_choices}), where each item in the list corresponds to a\"\n",
    "    \" summary.\\n---------------------\\n{context_list}\\n---------------------\\nUsing\"\n",
    "    \" only the choices above and not prior knowledge, return the top choices\"\n",
    "    \" (no more than {max_outputs}, but only select what is needed) that are\"\n",
    "    \" most relevant to the question: '{query_str}'\\n\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe20cf42-8f36-4b65-a4c2-3f0bd6895383",
   "metadata": {},
   "source": [
    "Let's try this prompt on a few toy questions and see what the model returns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "929a238e-1271-4a18-9f3c-da8a9cd9b5e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms import OpenAI\n",
    "\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8367510f-91e5-4509-a228-4becdb222edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_formatted_prompt(query_str):\n",
    "    fmt_prompt = router_prompt0.format(\n",
    "        num_choices=len(choices),\n",
    "        max_outputs=2,\n",
    "        context_list=choices_str,\n",
    "        query_str=query_str,\n",
    "    )\n",
    "    return fmt_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c8fc9c-5fdc-4ab6-90c5-f36031420d01",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_str = \"Can you tell me more about the amount of Vitamin C in apples\"\n",
    "fmt_prompt = get_formatted_prompt(query_str)\n",
    "response = llm.complete(fmt_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47800333-92ea-4cb3-8e11-2d795a635fdc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1. Useful for questions related to apples\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c111863-ea12-409c-9777-d3d387126601",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_str = \"What are the health benefits of eating orange peels?\"\n",
    "fmt_prompt = get_formatted_prompt(query_str)\n",
    "response = llm.complete(fmt_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fc70519-1805-4899-a4d4-a22cd49847c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2. Useful for questions related to oranges\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "100f6664-3bef-4f5a-89e8-6c66a5d687b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_str = (\n",
    "    \"Can you tell me more about the amount of Vitamin C in apples and oranges.\"\n",
    ")\n",
    "fmt_prompt = get_formatted_prompt(query_str)\n",
    "response = llm.complete(fmt_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1548044f-a0cd-4b3e-8678-4e49a97acdbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1. Useful for questions related to apples\n",
      "2. Useful for questions related to oranges\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb39f574-af06-495a-83e0-50866aca9caf",
   "metadata": {},
   "source": [
    "**Observation**: While the response corresponds to the correct choice, it is brittle to parse into a structured output (e.g. a single integer). We'd need string parsing to extract the number from each choice, and that parsing would have to be robust to variations in how the model words its answer."
   ]
  },
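  {
   "cell_type": "markdown",
   "id": "sketch-naive-choice-parsing",
   "metadata": {},
   "source": [
    "To make the fragility concrete, here is a minimal sketch of the kind of string parsing we would need. The helper name `parse_choice_numbers` and the two sample responses are hypothetical, written for illustration rather than taken from real LLM output.\n",
    "\n",
    "```python\n",
    "def parse_choice_numbers(response_text):\n",
    "    # Hypothetical helper: collect the integer that starts each line ('1. ...').\n",
    "    numbers = []\n",
    "    for line in response_text.splitlines():\n",
    "        head = line.strip().split('.', 1)[0]\n",
    "        if head.isdigit():\n",
    "            numbers.append(int(head))\n",
    "    return numbers\n",
    "\n",
    "\n",
    "clean = '1. Useful for questions related to apples'\n",
    "chatty = 'The most relevant choice is the first one (apples).'\n",
    "\n",
    "print(parse_choice_numbers(clean))   # [1]\n",
    "print(parse_choice_numbers(chatty))  # []\n",
    "```\n",
    "\n",
    "The second response names the right choice in words, yet the parser finds nothing. This is the failure mode that a structured output format is meant to avoid."
   ]
  },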
  {
   "cell_type": "markdown",
   "id": "f7a4cfd8-a515-4613-9f43-92563c14c846",
   "metadata": {},
   "source": [
    "## 2. A Router Prompt that can generate structured outputs\n",
    "\n",
    "Therefore the next step is to try to prompt the model to output a more structured representation (JSON). \n",
    "\n",
    "We define an output parser class (`RouterOutputParser`). This output parser will be responsible for both formatting the prompt and also parsing the result into a structured object (an `Answer`).\n",
    "\n",
    "We then apply the `format` and `parse` methods of the output parser around the LLM call using the router prompt to generate a structured output."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe2e2fb0-aa18-45c1-a8fe-0ff7124eddc1",
   "metadata": {},
   "source": [
    "### 2.a Define the Answer Class\n",
    "\n",
    "We define a simple `Answer` class: a Pydantic model with two fields, `choice` and `reason`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd0b6408-da38-494e-be00-d2ab0f1a9791",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from pydantic import BaseModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ecea17b-bbf8-4a33-9252-dc53c74cbbde",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Answer(BaseModel):\n",
    "    choice: int\n",
    "    reason: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa49cf5-b892-4834-a783-0f107cbafcb7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"title\": \"Answer\",\n",
      "  \"type\": \"object\",\n",
      "  \"properties\": {\n",
      "    \"choice\": {\n",
      "      \"title\": \"Choice\",\n",
      "      \"type\": \"integer\"\n",
      "    },\n",
      "    \"reason\": {\n",
      "      \"title\": \"Reason\",\n",
      "      \"type\": \"string\"\n",
      "    }\n",
      "  },\n",
      "  \"required\": [\n",
      "    \"choice\",\n",
      "    \"reason\"\n",
      "  ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(json.dumps(Answer.schema(), indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4e23bc5-0fc1-4da5-b5bc-4220f205d0e1",
   "metadata": {},
   "source": [
    "### 2.b Define Router Output Parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "585a9a25-0e9f-4da0-aa80-562fd0bb17f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.types import BaseOutputParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166b80df-5736-432c-b49b-03f7450ed524",
   "metadata": {},
   "outputs": [],
   "source": [
    "FORMAT_STR = \"\"\"The output should be formatted as a JSON instance that conforms to \n",
    "the JSON schema below. \n",
    "\n",
    "Here is the output schema:\n",
    "{\n",
    "  \"type\": \"array\",\n",
    "  \"items\": {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "      \"choice\": {\n",
    "        \"type\": \"integer\"\n",
    "      },\n",
    "      \"reason\": {\n",
    "        \"type\": \"string\"\n",
    "      }\n",
    "    },\n",
    "    \"required\": [\n",
    "      \"choice\",\n",
    "      \"reason\"\n",
    "    ],\n",
    "    \"additionalProperties\": false\n",
    "  }\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a089d844-e989-44ea-9908-846bc6e7d584",
   "metadata": {},
   "source": [
    "Since `FORMAT_STR` will be appended to a prompt template, we need to escape its curly braces so that they aren't treated as template variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74614e7a-2314-4702-8918-47cd4ec378bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _escape_curly_braces(input_string: str) -> str:\n",
    "    # Replace '{' with '{{' and '}' with '}}' to escape curly braces\n",
    "    escaped_string = input_string.replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n",
    "    return escaped_string"
   ]
  },
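  {
   "cell_type": "markdown",
   "id": "sketch-escape-braces-demo",
   "metadata": {},
   "source": [
    "As a quick illustration of why the escaping matters, the sketch below uses Python's built-in `str.format` as a stand-in for the prompt template. The helper `escape_curly_braces`, the schema snippet, and the template text are all hypothetical examples.\n",
    "\n",
    "```python\n",
    "def escape_curly_braces(text):\n",
    "    # Double every brace so str.format treats it as a literal character.\n",
    "    return text.replace('{', '{{').replace('}', '}}')\n",
    "\n",
    "\n",
    "schema_snippet = '{\"type\": \"integer\"}'\n",
    "template = 'Question: {query_str} Schema: ' + escape_curly_braces(schema_snippet)\n",
    "\n",
    "# Only {query_str} is substituted; the escaped braces come back as single braces.\n",
    "formatted = template.format(query_str='apples?')\n",
    "print(formatted)\n",
    "```\n",
    "\n",
    "Without the escaping step, `format` would try to interpret the schema's braces as a template variable and raise an error."
   ]
  },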
  {
   "cell_type": "markdown",
   "id": "8657b5d2-ac58-46bd-a3c2-bd0021788814",
   "metadata": {},
   "source": [
    "We now define a simple parsing function to extract out the JSON string from the LLM response (by searching for square brackets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3c0c59-f72a-4003-b850-58bd81b5bb54",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _marshal_output_to_json(output: str) -> str:\n",
    "    output = output.strip()\n",
    "    # keep everything from the first '[' through the last ']', so that a\n",
    "    # ']' inside a reason string does not truncate the JSON array\n",
    "    left = output.find(\"[\")\n",
    "    right = output.rfind(\"]\")\n",
    "    output = output[left : right + 1]\n",
    "    return output"
   ]
  },
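  {
   "cell_type": "markdown",
   "id": "sketch-extract-json-array",
   "metadata": {},
   "source": [
    "Here is a small, self-contained sketch of this bracket-based extraction applied to a response that wraps the JSON in extra text. The helper `extract_json_array` and the sample response are hypothetical; this variant looks for the first `[` and the last `]` (via `rfind`) and raises an error when no array is present.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def extract_json_array(output):\n",
    "    # Keep everything from the first '[' through the last ']'.\n",
    "    left = output.find('[')\n",
    "    right = output.rfind(']')\n",
    "    if left == -1 or right == -1 or right < left:\n",
    "        raise ValueError('no JSON array found in output')\n",
    "    return output[left : right + 1]\n",
    "\n",
    "\n",
    "raw_output = 'Sure! Here is the result: [{\"choice\": 2, \"reason\": \"Mentions oranges\"}] Hope that helps.'\n",
    "parsed = json.loads(extract_json_array(raw_output))\n",
    "print(parsed[0]['choice'])\n",
    "```\n",
    "\n",
    "The surrounding chatter is discarded, and `json.loads` only sees the array."
   ]
  },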
  {
   "cell_type": "markdown",
   "id": "c8d8db4d-475a-4fbf-8025-b2eae468b028",
   "metadata": {},
   "source": [
    "We put these together in our `RouterOutputParser`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c04a83fe-624e-48fd-9dc3-f4aa85fb335c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "\n",
    "class RouterOutputParser(BaseOutputParser):\n",
    "    def parse(self, output: str) -> List[Answer]:\n",
    "        \"\"\"Parse string.\"\"\"\n",
    "        json_output = _marshal_output_to_json(output)\n",
    "        json_dicts = json.loads(json_output)\n",
    "        # Answer is a Pydantic model, so build it from keyword arguments\n",
    "        answers = [Answer(**json_dict) for json_dict in json_dicts]\n",
    "        return answers\n",
    "\n",
    "    def format(self, prompt_template: str) -> str:\n",
    "        return prompt_template + \"\\n\\n\" + _escape_curly_braces(FORMAT_STR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c20c8fe-9f73-41a7-9760-b2559d475769",
   "metadata": {},
   "source": [
    "### 2.c Give it a Try\n",
    "\n",
    "We create a function called `route_query` that takes in a query string, the list of choices, and the output parser, and returns a list of structured answers. It uses the `llm` and `router_prompt0` defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46126685-a108-4404-9307-73de5f51a9fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_parser = RouterOutputParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e67dfd6-42a0-4818-8642-6999748843bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "\n",
    "def route_query(\n",
    "    query_str: str, choices: List[str], output_parser: RouterOutputParser\n",
    "):\n",
    "    # build the numbered list from the choices passed in\n",
    "    choices_str = get_choice_str(choices)\n",
    "\n",
    "    fmt_base_prompt = router_prompt0.format(\n",
    "        num_choices=len(choices),\n",
    "        max_outputs=len(choices),\n",
    "        context_list=choices_str,\n",
    "        query_str=query_str,\n",
    "    )\n",
    "    fmt_json_prompt = output_parser.format(fmt_base_prompt)\n",
    "\n",
    "    raw_output = llm.complete(fmt_json_prompt)\n",
    "    parsed = output_parser.parse(str(raw_output))\n",
    "\n",
    "    return parsed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e97c0b4b-336e-4550-a4b6-950a141f1b57",
   "metadata": {},
   "source": [
    "## 3. Perform Routing with a Function Calling Endpoint\n",
    "\n",
    "In the previous section, we showed how to build a router with a text completion endpoint. This involved formatting the prompt to encourage the model to output structured JSON, and writing a parse function to load that JSON.\n",
    "\n",
    "This process can feel a bit messy. Function calling endpoints (e.g. OpenAI) abstract away this complexity by allowing the model to natively output structured function calls. This removes the need to manually prompt for and parse the outputs.\n",
    "\n",
    "LlamaIndex offers an abstraction called a `PydanticProgram` that integrates with a function calling endpoint to produce a structured Pydantic object. We integrate with OpenAI and Guidance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ceb461a-5c64-4efe-b36b-f02bef414e9e",
   "metadata": {},
   "source": [
    "We redefine our `Answer` class with a docstring (which shows up as the `description` in the schema), and add an `Answers` class containing a list of answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a62bb36-b64f-4900-a89c-ddfc39afbeac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import Field\n",
    "\n",
    "\n",
    "class Answer(BaseModel):\n",
    "    \"\"\"Represents a single choice with a reason.\"\"\"\n",
    "\n",
    "    choice: int\n",
    "    reason: str\n",
    "\n",
    "\n",
    "class Answers(BaseModel):\n",
    "    \"\"\"Represents a list of answers.\"\"\"\n",
    "\n",
    "    answers: List[Answer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32cc480-4d10-4a67-90f0-ee8cc002a559",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'title': 'Answers',\n",
       " 'description': 'Represents a list of answers.',\n",
       " 'type': 'object',\n",
       " 'properties': {'answers': {'title': 'Answers',\n",
       "   'type': 'array',\n",
       "   'items': {'$ref': '#/definitions/Answer'}}},\n",
       " 'required': ['answers'],\n",
       " 'definitions': {'Answer': {'title': 'Answer',\n",
       "   'description': 'Represents a single choice with a reason.',\n",
       "   'type': 'object',\n",
       "   'properties': {'choice': {'title': 'Choice', 'type': 'integer'},\n",
       "    'reason': {'title': 'Reason', 'type': 'string'}},\n",
       "   'required': ['choice', 'reason']}}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Answers.schema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c65d01c-b56e-49ae-89c6-268d29bb25c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.program import OpenAIPydanticProgram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6f7879-b3f6-4f65-94df-e2ac2edc0a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "router_prompt1 = router_prompt0.partial_format(\n",
    "    num_choices=len(choices),\n",
    "    max_outputs=len(choices),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acae0e8c-442e-43f5-9753-bd1afff12ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "program = OpenAIPydanticProgram.from_defaults(\n",
    "    output_cls=Answers,\n",
    "    prompt=router_prompt1,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a44e079d-545c-4a75-abc2-409e8fd5c85e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function call: Answers with args: {\n",
      "  \"answers\": [\n",
      "    {\n",
      "      \"choice\": 2,\n",
      "      \"reason\": \"Orange peels are related to oranges\"\n",
      "    }\n",
      "  ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "query_str = \"What are the health benefits of eating orange peels?\"\n",
    "output = program(context_list=choices_str, query_str=query_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d914bf0e-a1ec-49cf-adfd-3636c536908e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Answers(answers=[Answer(choice=2, reason='Orange peels are related to oranges')])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc1e1ce-6075-4655-9c6c-3882806da483",
   "metadata": {},
   "source": [
    "## 4. Plug the Router Module into a RAG Pipeline\n",
    "\n",
    "In this section we'll put the router module to use in a RAG pipeline. We'll use it to dynamically decide whether to perform question-answering or summarization. Question-answering uses a query engine that does top-k retrieval over our vector index, while summarization uses a query engine over our summary index. Each query engine is described as a \"choice\" to our router, and we compose the whole thing into a single query engine."
   ]
  },
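  {
   "cell_type": "markdown",
   "id": "sketch-router-dispatch",
   "metadata": {},
   "source": [
    "Once the router has picked its choices, the dispatch step reduces to indexing into a list of engines. The sketch below uses plain Python functions as stand-ins for query engines; `dispatch`, `qa_engine`, and `summary_engine` are hypothetical names for illustration, not LlamaIndex APIs.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List\n",
    "\n",
    "\n",
    "def dispatch(\n",
    "    query: str, selected: List[int], engines: List[Callable[[str], str]]\n",
    ") -> List[str]:\n",
    "    # Router choices are 1-indexed, so shift by one before indexing.\n",
    "    return [engines[choice - 1](query) for choice in selected]\n",
    "\n",
    "\n",
    "def qa_engine(query: str) -> str:\n",
    "    return 'QA answer for: ' + query\n",
    "\n",
    "\n",
    "def summary_engine(query: str) -> str:\n",
    "    return 'Summary for: ' + query\n",
    "\n",
    "\n",
    "responses = dispatch('Summarize the paper', [2], [qa_engine, summary_engine])\n",
    "print(responses)\n",
    "```\n",
    "\n",
    "When more than one choice is selected, the list holds one response per engine, and a further step is needed to combine them into a single answer."
   ]
  },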
  {
   "cell_type": "markdown",
   "id": "34382f7b-11ff-44ef-95a3-63eb4723f9b7",
   "metadata": {},
   "source": [
    "### Setup: Load Data\n",
    "\n",
    "We load the Llama 2 paper as data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d16ec2-c293-4d58-850d-80bc90b01a38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mkdir: data: File exists\n",
      "--2023-09-17 23:37:11--  https://arxiv.org/pdf/2307.09288.pdf\n",
      "Resolving arxiv.org (arxiv.org)... 128.84.21.199\n",
      "Connecting to arxiv.org (arxiv.org)|128.84.21.199|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 13661300 (13M) [application/pdf]\n",
      "Saving to: ‘data/llama2.pdf’\n",
      "\n",
      "data/llama2.pdf     100%[===================>]  13.03M  1.50MB/s    in 9.5s    \n",
      "\n",
      "2023-09-17 23:37:22 (1.37 MB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p data\n",
    "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcd9b401-1212-4c33-a18a-14f39ecb9989",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_hub.file.pymu_pdf.base import PyMuPDFReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e538688b-c578-4807-a78b-c0db5ad722ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyMuPDFReader()\n",
    "documents = loader.load(file_path=\"./data/llama2.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce59ee6c-9197-4709-9208-a0e3450a641d",
   "metadata": {},
   "source": [
    "### Setup: Define Indexes\n",
    "\n",
    "Define both a vector index and summary index over this data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8615db4e-7572-4828-b55f-f846d653aa5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import ServiceContext, VectorStoreIndex, SummaryIndex\n",
    "\n",
    "service_context = ServiceContext.from_defaults(chunk_size=1024)\n",
    "vector_index = VectorStoreIndex.from_documents(\n",
    "    documents, service_context=service_context\n",
    ")\n",
    "summary_index = SummaryIndex.from_documents(\n",
    "    documents, service_context=service_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23103269-6e9f-4e68-b964-4b7a1238a7f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_query_engine = vector_index.as_query_engine()\n",
    "summary_query_engine = summary_index.as_query_engine()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85126015-7873-4540-a977-8ac351a0f427",
   "metadata": {},
   "source": [
    "### Define RouterQueryEngine\n",
    "\n",
    "We subclass our `CustomQueryEngine` to define a custom router."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cdf6b37-31de-4c23-9e85-9cbfb732b591",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.query_engine import CustomQueryEngine, BaseQueryEngine\n",
    "from llama_index.response_synthesizers import TreeSummarize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99ec5310-5385-48a6-863a-ced4f2aeb303",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RouterQueryEngine(CustomQueryEngine):\n",
    "    \"\"\"Use our Pydantic program to perform routing.\"\"\"\n",
    "\n",
    "    query_engines: List[BaseQueryEngine]\n",
    "    choice_descriptions: List[str]\n",
    "    verbose: bool = False\n",
    "    router_prompt: PromptTemplate\n",
    "    llm: OpenAI\n",
    "    summarizer: TreeSummarize = Field(default_factory=TreeSummarize)\n",
    "\n",
    "    def custom_query(self, query_str: str):\n",
    "        \"\"\"Define custom query.\"\"\"\n",
    "\n",
    "        # use the prompt passed to this engine, not the global one\n",
    "        program = OpenAIPydanticProgram.from_defaults(\n",
    "            output_cls=Answers,\n",
    "            prompt=self.router_prompt,\n",
    "            verbose=self.verbose,\n",
    "            llm=self.llm,\n",
    "        )\n",
    "\n",
    "        choices_str = get_choice_str(self.choice_descriptions)\n",
    "        output = program(context_list=choices_str, query_str=query_str)\n",
    "        # print choice and reason, and query the underlying engine\n",
    "        if self.verbose:\n",
    "            print(\"Selected choice(s):\")\n",
    "            for answer in output.answers:\n",
    "                print(f\"Choice: {answer.choice}, Reason: {answer.reason}\")\n",
    "\n",
    "        responses = []\n",
    "        for answer in output.answers:\n",
    "            # choices are 1-indexed in the prompt\n",
    "            choice_idx = answer.choice - 1\n",
    "            query_engine = self.query_engines[choice_idx]\n",
    "            response = query_engine.query(query_str)\n",
    "            responses.append(response)\n",
    "\n",
    "        # if a single choice is picked, we can just return that response\n",
    "        if len(responses) == 1:\n",
    "            return responses[0]\n",
    "        else:\n",
    "            # if multiple choices are picked, combine them with the summarizer\n",
    "            response_strs = [str(r) for r in responses]\n",
    "            result_response = self.summarizer.get_response(\n",
    "                query_str, response_strs\n",
    "            )\n",
    "            return result_response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "153701ce-cd1a-4318-b575-afc508e10f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "choices = [\n",
    "    (\n",
    "        \"Useful for answering questions about specific sections of the Llama 2\"\n",
    "        \" paper\"\n",
    "    ),\n",
    "    \"Useful for questions that ask for a summary of the whole paper\",\n",
    "]\n",
    "\n",
    "router_query_engine = RouterQueryEngine(\n",
    "    query_engines=[vector_query_engine, summary_query_engine],\n",
    "    choice_descriptions=choices,\n",
    "    verbose=True,\n",
    "    router_prompt=router_prompt1,\n",
    "    llm=OpenAI(model=\"gpt-4\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "865f15fa-710c-41eb-a2a2-745a8e3c0ced",
   "metadata": {},
   "source": [
    "### Try our constructed Router Query Engine\n",
    "\n",
    "Let's take our self-built router query engine for a spin! We ask a question that routes to the vector query engine, and also another question that routes to the summarization engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7769a3ec-f6b6-44ec-acab-ed8a14a6b32b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function call: Answers with args: {\n",
      "  \"answers\": [\n",
      "    {\n",
      "      \"choice\": 1,\n",
      "      \"reason\": \"This question is asking for specific information about the Llama 2 model and its comparison to GPT-4 in the experimental results. Therefore, the summary that is useful for answering questions about specific sections of the paper would be most relevant.\"\n",
      "    }\n",
      "  ]\n",
      "}\n",
      "Selected choice(s):\n",
      "Choice: 1, Reason: This question is asking for specific information about the Llama 2 model and its comparison to GPT-4 in the experimental results. Therefore, the summary that is useful for answering questions about specific sections of the paper would be most relevant.\n"
     ]
    }
   ],
   "source": [
    "response = router_query_engine.query(\n",
    "    \"How does the Llama 2 model compare to GPT-4 in the experimental results?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27448520-0727-47e5-9888-eb4cdefbc890",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Llama 2 model performs better than GPT-4 in the experimental results.\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bfc9df7-c48b-477e-a7ac-a0f8ca4a9e21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function call: Answers with args: {\n",
      "  \"answers\": [\n",
      "    {\n",
      "      \"choice\": 2,\n",
      "      \"reason\": \"This choice is directly related to providing a summary of the whole paper, which is what the question asks for.\"\n",
      "    }\n",
      "  ]\n",
      "}\n",
      "Selected choice(s):\n",
      "Choice: 2, Reason: This choice is directly related to providing a summary of the whole paper, which is what the question asks for.\n"
     ]
    }
   ],
   "source": [
    "response = router_query_engine.query(\"Can you give a summary of this paper?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab3c9756-81ca-4b4f-8ae1-0e37d9e1464b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(str(response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
